import numpy as np


class GridRunEnv:
    """Small 8x8 grid world with lava, water, recharge and dry tiles.

    The internal state is ``(x, y, wet)``. Actions are:

    - 0 = up (y + 1)
    - 1 = down (y - 1)
    - 2 = left (x - 1)
    - 3 = right (x + 1)

    A move that would leave the grid is ignored. After each move, with
    probability ``wind_probability``, the agent is pushed one cell toward
    ``y = 0``.

    Entering a water tile sets ``wet = 1``; entering a dry tile clears it.
    The episode ends with reward -100 on a lava tile, or +100 on a recharge
    tile while dry. Every other step yields reward 0.

    Observations are ``(x, y, tile_type)``, where tile_type is:

    - 0 = plain
    - 1 = lava
    - 2 = water
    - 3 = recharge
    - 4 = dry

    The ``wet`` flag is not part of the observation.
    """

    def __init__(self, wind_probability=1/32):
        """Create the environment in its initial state.

        Args:
            wind_probability: chance per step that the wind pushes the agent
                one cell toward y = 0. Defaults to 1/32.
        """
        super().__init__()

        self.grid_size = 8
        # Intended gym-style spaces (documented only, not enforced here):
        #   action:      Discrete(4)  -> up, down, left, right
        #   observation: (x, y, tile_type) with x, y in range(grid_size)
        #                and tile_type in range(5)

        # x=7, y=0 is the max-x / min-y corner under step()'s action mapping
        # (Left decreases x, Down decreases y). The agent starts dry.
        # NOTE(review): (7, 0) is also a recharge tile. From the start, a move
        # blocked by the grid edge (down or right) therefore ends the episode
        # with +100 immediately. Confirm this is intended.
        self.init_state = (7, 0, 0)
        self.wind_probability = wind_probability

        # Tile coordinates as (x, y). The four lists are pairwise disjoint.
        self.lava_tiles = [(0, 2), (0, 3), (0, 6), (1, 6), (3, 6), (4, 6), (5, 6), (6, 6), (7, 6)]
        self.water_tiles = [(4, 2), (4, 3), (4, 4), (4, 5), (6, 2), (6, 3), (6, 4), (6, 5)]
        self.recharge_tiles = [(0, 7), (7, 0), (5, 7)]
        self.dry_tiles = [(2, 0), (7, 7)]

        # Set a valid state up front. Without this, step() / get_safe_action()
        # called before the first reset() would raise AttributeError.
        self.state = self.init_state

    def step(self, action):
        """Apply ``action`` and advance the environment by one step.

        Args:
            action: 0 = up, 1 = down, 2 = left, 3 = right. Any other value,
                or a move off the grid, leaves the position unchanged.

        Returns:
            ``(observation, reward, done, info)``, where ``info`` is an
            empty dict.
        """
        x, y, wet = self.state

        if action == 0 and y < self.grid_size - 1:  # Up
            y += 1
        elif action == 1 and y > 0:  # Down
            y -= 1
        elif action == 2 and x > 0:  # Left
            x -= 1
        elif action == 3 and x < self.grid_size - 1:  # Right
            x += 1

        # Wind is applied after the agent's own move and can push it onto
        # any tile, lava included. It draws from NumPy's global RNG.
        if np.random.rand() < self.wind_probability:
            y = max(0, y - 1)

        pos = (x, y)
        if pos in self.water_tiles:
            wet = 1
        if wet == 1 and pos in self.dry_tiles:
            wet = 0

        self.state = (x, y, wet)

        # The same two predicates drive both the reward and termination.
        on_lava = pos in self.lava_tiles
        recharged = wet == 0 and pos in self.recharge_tiles
        if on_lava:
            reward = -100
        elif recharged:
            reward = 100
        else:
            reward = 0
        done = on_lava or recharged

        return self._get_obs(), reward, done, {}

    def reset(self):
        """Return to the initial state and return its observation."""
        self.state = self.init_state
        return self._get_obs()

    def _get_obs(self):
        """Return ``(x, y, tile_type)`` for the current state.

        If a cell appeared in several tile lists, the first match in the
        order lava (1), water (2), recharge (3), dry (4) would win. A cell in
        none of them is plain (0).
        """
        x, y = self.state[0], self.state[1]
        pos = (x, y)
        for tiles, tile_type in (
            (self.lava_tiles, 1),
            (self.water_tiles, 2),
            (self.recharge_tiles, 3),
            (self.dry_tiles, 4),
        ):
            if pos in tiles:
                return x, y, tile_type
        return x, y, 0

    def get_safe_action(self):
        """Return the first in-grid action whose target cell is not lava.

        Actions are tried in the order up (0), down (1), left (2),
        right (3). Only the move's destination is checked; the wind is not
        taken into account.

        Returns:
            The action id, or None if every in-grid neighbour is lava.
        """
        x, y = self.state[0], self.state[1]
        last = self.grid_size - 1
        candidates = (
            (0, y < last, (x, y + 1)),
            (1, y > 0, (x, y - 1)),
            (2, x > 0, (x - 1, y)),
            (3, x < last, (x + 1, y)),
        )
        for action, in_grid, target in candidates:
            if in_grid and target not in self.lava_tiles:
                return action
        return None
